{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n1YlS56MXcaG"
      },
      "outputs": [],
      "source": [
        "# Copyright 2023 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "# Fetch the sparse_soft_topk sources with git. The clone is skipped when a\n",
        "# directory named sparse_soft_topk already exists in the working directory,\n",
        "# so this cell can be re-run without failing on an existing checkout.\n",
        "![ -d sparse_soft_topk ] || git clone https://github.com/google-research/sparse_soft_topk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vshQNNg9iR7X"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import jax.numpy as jnp\n",
        "import jax\n",
        "import sparse_soft_topk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Z6PHVUQURCg"
      },
      "outputs": [],
      "source": [
        "def hard_topk_mask(s, k):\n",
        "  \"\"\"Returns the hard top-k mask of a 1-D score vector.\n",
        "\n",
        "  Args:\n",
        "    s: 1-D array of scores.\n",
        "    k: Number of entries to select.\n",
        "\n",
        "  Returns:\n",
        "    An array shaped like `s` holding 1.0 at the positions of the k largest\n",
        "    scores and 0.0 elsewhere. Ties are broken in favour of the lower index,\n",
        "    so exactly min(k, len(s)) entries are set for any k \u003e= 0.\n",
        "  \"\"\"\n",
        "  # Rank every score in descending order (rank 0 = largest). Comparing ranks\n",
        "  # instead of thresholding on s.sort()[::-1][k] avoids an out-of-range index\n",
        "  # when k == len(s) (JAX clamps it silently, which dropped the smallest\n",
        "  # entry) and still selects exactly k entries when scores are tied.\n",
        "  ranks = jnp.argsort(jnp.argsort(-s))\n",
        "  return jnp.where(ranks \u003c k, jnp.ones(s.shape), jnp.zeros(s.shape))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A8VtLyoSBN6y"
      },
      "outputs": [],
      "source": [
        "# Input scores: one row per grid point, four entries per row.\n",
        "n_discretization = 500\n",
        "x = jnp.linspace(0, 5, n_discretization)\n",
        "# Two entries stay constant (3 and 1) while the other two sweep with x.\n",
        "columns = [3 * jnp.ones_like(x), jnp.ones_like(x), x - 1, x]\n",
        "s = jnp.stack(columns, axis=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o1S-QtO-A_KY"
      },
      "outputs": [],
      "source": [
        "# Regularization scale and number of selected entries\n",
        "l = 6e-1\n",
        "k = 2\n",
        "\n",
        "# Hard top-k: map the 1-D operator over the rows of s, then transpose so\n",
        "# that indexing the result picks one entry across all rows.\n",
        "batched_hard_topk = jax.vmap(hard_topk_mask, in_axes=(0, None))\n",
        "hards = batched_hard_topk(s, k).T\n",
        "\n",
        "# Sparse and smooth top-k masks from the library's PAV-based solver:\n",
        "# first with l2 regularization, then with l-(4/3) regularization.\n",
        "soft_topk_mask = sparse_soft_topk.sparse_soft_topk_mask_pav\n",
        "softs_2 = soft_topk_mask(s, k=k, l=l, p=2.0).T\n",
        "softs_4 = soft_topk_mask(s, k=k, l=l, p=4/3).T"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k4FLLrUeARlO"
      },
      "outputs": [],
      "source": [
        "# Plot the summed mask weight of entries 1 and 2 (0-based) against x\n",
        "p = 0.9  # scale factor applied to the base 4x3 figure size\n",
        "fig, ax = plt.subplots(figsize=(4 * p, 3 * p))\n",
        "\n",
        "ax.plot(x, hards[1] + hards[2], \"--\", lw=3, label=\"Hard\")\n",
        "ax.plot(x, softs_2[1] + softs_2[2], lw=3, label=\"$p = 2$\")\n",
        "ax.plot(x, softs_4[1] + softs_4[2], lw=3, label=\"$p = 4/3$\", alpha=0.6)\n",
        "\n",
        "ax.set_xlabel(r\"$s$\")\n",
        "ax.set_ylabel(r\"$\\mathrm{topkmask}(\\theta(s))_2 + \\mathrm{topkmask}(\\theta(s))_3$\")\n",
        "ax.legend(loc=\"upper center\")\n",
        "fig.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QtlHlbidXHga"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "11Twk6kQ9D5BId4UU6h--AWyLX3zPPvUe",
          "timestamp": 1677587327795
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
